# -*- coding: utf-8 -*-
# @Author  : xinghen
# @Date    : 2020-11-25

#正式版本

import csv
from sklearn.feature_extraction import DictVectorizer
from sklearn import preprocessing
from sklearn import tree
from six import StringIO
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import pydot, pydotplus
from category_encoders import OneHotEncoder

import lightgbm

# ---------------------------------------------------------------------------
# Decision-tree training / evaluation script.
#
# Reads the feature table (data.csv) and the target table (result1.csv),
# one-hot encodes the categorical features, trains an entropy-based decision
# tree, exports the tree to a PDF and prints the hold-out accuracy.
# ---------------------------------------------------------------------------

# Fraction of the rows held out for evaluation.
TEST_SIZE = 0.3

# Categorical feature columns that are one-hot encoded.
CATEGORICAL_COLS = ["井别", "井型", "大斜度井", "井段类型", "岩性", "渗透性", "流体性质", "温压", "下部压力"]

df_in = pd.read_csv("data.csv")
# Header names of data.csv, captured before '井名' is moved into the index
# (re-reading the file just for its header is unnecessary).
col_names = df_in.columns
# Intended dtype of each feature column.
# NOTE(review): this mapping is not passed to read_csv anywhere in this script,
# so it currently has no effect — confirm whether dtype=types_dict was intended.
types_dict = {"井别": str, "井型": str, "大斜度井": str, "井段类型": str, "岩性": str, "渗透性": str, "流体性质": str, "温压": str,
              "下部压力": str,
              "最大井斜": float, "温度预测": float, "压力预测": float, "井眼尺寸": float}

df_in = df_in.set_index('井名')

# Targets; missing values are replaced by 0.
df_out = pd.read_csv("result1.csv").fillna(0)

# Split features and targets in ONE call so both receive the same shuffle.
# Two independent train_test_split calls would permute the rows differently
# and pair every sample with some other sample's label.
# Assumes row i of data.csv corresponds to row i of result1.csv — TODO confirm;
# train_test_split raises ValueError if the two tables differ in length.
df_x, df_x_test, df_y, df_y_test = train_test_split(df_in, df_out, test_size=TEST_SIZE)

# Learn the one-hot encoding on the training features only.
ohe = OneHotEncoder(cols=CATEGORICAL_COLS,
                    handle_unknown='indicator',
                    handle_missing='indicator',
                    use_cat_names=True).fit(df_x)

print(df_y)

df_ohe_x = ohe.transform(df_x)  # encoded training features

# The class label is the last column of the target table.
# The LabelEncoder is fitted once on the complete label column, so that the
# train and test splits share one label -> integer mapping and a class that
# only occurs in the test split does not make transform() fail.
# (enc.classes_ lists the distinct classes.)
enc = LabelEncoder().fit(df_out.iloc[:, -1])
y = df_y.iloc[:, -1]
label_y = enc.transform(y)

# Build the decision tree, using information gain (entropy) as split criterion.
clf = tree.DecisionTreeClassifier(criterion='entropy')
clf = clf.fit(df_ohe_x, label_y)
print('clf:' + str(clf))

# Export the fitted tree as a PDF diagram.
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("sz_tree.pdf")

# Apply the encoding learned on the training set to the hold-out features
# and predict their classes.
df_ohe_x_test = ohe.transform(df_x_test)
predictedY = clf.predict(df_ohe_x_test)

# Encode the hold-out labels with the SAME mapping used for training
# (transform, not fit_transform) so they are comparable to the predictions.
ec_y_test = enc.transform(df_y_test.iloc[:, -1])
count = int((ec_y_test == predictedY).sum())  # number of correct predictions

print(' Y is :' + str(ec_y_test))  # encoded true labels of the hold-out set
print('pY is :' + str(predictedY))  # encoded predicted labels of the hold-out set
print("acc : " + str(count/len(ec_y_test)))
